"""Define basic models and translate some torchvision stuff."""
"""Stuff from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py."""
import torch
import torchvision
import torch.nn as nn

from torchvision.models.resnet import Bottleneck
from .revnet import iRevNet
from .densenet import _DenseNet, _Bottleneck
from vit_timm import vit_small_patch16_224
from collections import OrderedDict
import numpy as np
from ..utils import set_random_seed
import sys
sys.path.append("/disk/Jiahui/47_2080ti/djc/gradient-inversion-bench - 20240129/inversefed/nn")
import resnetPreact



def construct_model(model, num_classes=10, seed=None, num_channels=3, modelkey=None):
    """Return the network named ``model`` together with the seed used to initialise it.

    The seed is ``modelkey`` if given, otherwise ``seed``, otherwise a random
    integer drawn with NumPy. ``set_random_seed`` is called with it before the
    model is built (presumably so that a given key reproduces the same initial
    weights).

    Args:
        model: architecture name, e.g. 'ConvNet', 'LeNet', 'ResNet18', 'vit'.
        num_classes: number of output classes, for architectures that take it.
        seed: seed used when ``modelkey`` is None.
        num_channels: number of input channels, for architectures that take it.
        modelkey: explicit seed; takes precedence over ``seed``.

    Returns:
        ``(model, model_init_seed)``: the constructed ``nn.Module`` and the seed.

    Raises:
        NotImplementedError: if ``model`` is not a known architecture name.
    """
    if modelkey is not None:
        model_init_seed = modelkey
    elif seed is not None:
        model_init_seed = seed
    else:
        model_init_seed = np.random.randint(0, 2**31 - 10)
    set_random_seed(model_init_seed)

    if model in ['ConvNet', 'ConvNet64', 'ConvNet8', 'ConvNet16', 'ConvNet32']:
        # All of these names build the same width-64 ConvNet.
        model = ConvNet(width=64, num_channels=num_channels, num_classes=num_classes)
    elif model == "ConCNN":
        # Ablation switches; the trailing numbers are the original author's
        # experiment notes for each setting (metric not recorded here).
        #False 3 0 18.67
        #False 1 0 12.18
        use_bias = False  # 1:12.93 3:21.24
        use_relu = False  # 1:14.29 3: 17.23
        use_dropout = False  # 1:12.18  3: 18.67
        use_maxpool = False  # 10.70
        kernel_size = 3
        padding = 0
        model = ConfigurableVGG11(num_classes=num_classes, use_bias=use_bias, use_relu=use_relu,
                                  use_dropout=use_dropout, use_maxpool=use_maxpool,
                                  conv_kernel_size=kernel_size, conv_padding=padding)
    elif model == "LeNet":
        model = LeNet(channel=3, num_classes=num_classes)
    elif model == "LeNetMnist":
        model = LeNet(channel=1, num_classes=num_classes)
    elif model == "LeNetEMnist":
        model = LeNet_em()
    elif model == 'SimpleCNN':
        model = SimpleCNN()
    elif model == 'SimpleCNNMnist':
        model = SimpleCNNMnist()
    elif model == "Alexnet":
        # Previously hard-coded to 10 classes, ignoring ``num_classes``.
        model = AlexNet(num_classes=num_classes, num_channels=3)
    elif model == "Vgg16":
        model = VGG(make_layers(cfg['D'], batch_norm=True), num_class=num_classes)
    elif model == "Vgg11":
        model = VGG(make_layers(cfg['A'], batch_norm=True), num_class=num_classes)
    elif model == "Vgg13":
        model = VGG(make_layers(cfg['B'], batch_norm=True), num_class=num_classes)
    elif model == "Vgg19":
        model = torchvision.models.vgg19(pretrained=False, num_classes=num_classes)
    elif model == "AlexnetMnist":
        # Previously hard-coded to 10 classes, ignoring ``num_classes``.
        model = AlexNet(num_classes=num_classes, num_channels=1)
    elif model == "resnetPreact":
        model = resnetPreact.Network(num_classes=num_classes)
    elif model == 'BeyondInferringMNIST':
        # 12544 = 256 * 7 * 7: two stride-2 convs take a 28x28 input down to 7x7.
        model = torch.nn.Sequential(OrderedDict([
            ('conv1', torch.nn.Conv2d(1, 32, 3, stride=2, padding=1)),
            ('relu0', torch.nn.LeakyReLU()),
            ('conv2', torch.nn.Conv2d(32, 64, 3, stride=1, padding=1)),
            ('relu1', torch.nn.LeakyReLU()),
            ('conv3', torch.nn.Conv2d(64, 128, 3, stride=2, padding=1)),
            ('relu2', torch.nn.LeakyReLU()),
            ('conv4', torch.nn.Conv2d(128, 256, 3, stride=1, padding=1)),
            ('relu3', torch.nn.LeakyReLU()),
            ('flatt', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(12544, 12544)),
            ('relu4', torch.nn.LeakyReLU()),
            ('linear1', torch.nn.Linear(12544, num_classes)),
            ('softmax', torch.nn.Softmax(dim=1))
        ]))
    elif model == 'BeyondInferringCifar':
        # NOTE(review): 12544 = 256 * 7 * 7 matches 28x28 inputs; a 32x32 input
        # would flatten to 256 * 8 * 8 = 16384 features — confirm the input size used.
        model = torch.nn.Sequential(OrderedDict([
            ('conv1', torch.nn.Conv2d(3, 32, 3, stride=2, padding=1)),
            ('relu0', torch.nn.LeakyReLU()),
            ('conv2', torch.nn.Conv2d(32, 64, 3, stride=1, padding=1)),
            ('relu1', torch.nn.LeakyReLU()),
            ('conv3', torch.nn.Conv2d(64, 128, 3, stride=2, padding=1)),
            ('relu2', torch.nn.LeakyReLU()),
            ('conv4', torch.nn.Conv2d(128, 256, 3, stride=1, padding=1)),
            ('relu3', torch.nn.LeakyReLU()),
            ('flatt', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(12544, 12544)),
            ('relu4', torch.nn.LeakyReLU()),
            ('linear1', torch.nn.Linear(12544, num_classes)),
            ('softmax', torch.nn.Softmax(dim=1))
        ]))
    elif model == 'MLP':
        width = 1024
        # Input is a flattened num_channels x 32 x 32 image (3072 for RGB).
        model = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(num_channels * 32 * 32, width)),
            ('relu0', torch.nn.ReLU()),
            ('linear1', torch.nn.Linear(width, width)),
            ('relu1', torch.nn.ReLU()),
            ('linear2', torch.nn.Linear(width, width)),
            ('relu2', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(width, num_classes))]))
    elif model == 'TwoLP':
        width = 2048
        # Input is a flattened num_channels x 32 x 32 image (3072 for RGB).
        model = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(num_channels * 32 * 32, width)),
            ('relu0', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(width, num_classes))]))
    # NOTE(review): the ResNet class visible in this chunk only takes
    # (block, num_blocks, num_classes); the calls below that pass base_width /
    # strides / pool presumably rely on a different ResNet definition — confirm.
    elif model == 'ResNet20':
        model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16)
    elif model == 'ResNet20-nostride':
        model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16,
                       strides=[1, 1, 1, 1])
    elif model == 'ResNet20-10':
        model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16 * 10)
    elif model == 'ResNet20-4':
        model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16 * 4)
    elif model == 'ResNet20-4-unpooled':
        model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16 * 4,
                       pool='max')
    elif model == 'ResNet28-10':
        model = ResNet(torchvision.models.resnet.BasicBlock, [4, 4, 4], num_classes=num_classes, base_width=16 * 10)
    elif model == 'ResNet32':
        model = ResNet(torchvision.models.resnet.BasicBlock, [5, 5, 5], num_classes=num_classes, base_width=16)
    elif model == 'ResNet32-10':
        model = ResNet(torchvision.models.resnet.BasicBlock, [5, 5, 5], num_classes=num_classes, base_width=16 * 10)
    elif model == 'ResNet44':
        model = ResNet(torchvision.models.resnet.BasicBlock, [7, 7, 7], num_classes=num_classes, base_width=16)
    elif model == 'ResNet56':
        model = ResNet(torchvision.models.resnet.BasicBlock, [9, 9, 9], num_classes=num_classes, base_width=16)
    elif model == 'ResNet110':
        model = ResNet(torchvision.models.resnet.BasicBlock, [18, 18, 18], num_classes=num_classes, base_width=16)
    elif model == 'ResNet18':
        from .cifar10_models.resnet import resnet18
        # resnet50 from the same module accepts num_classes (see 'ResNet50' below);
        # previously the argument was dropped here.
        model = resnet18(pretrained=False, num_classes=num_classes)
    elif model == 'ResNet34':
        model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)
    elif model == 'ResNet50':
        from .cifar10_models.resnet import resnet50
        model = resnet50(pretrained=False, num_classes=num_classes)
    elif model == 'ResNet50-2':
        model = ResNet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], num_classes=num_classes, base_width=64 * 2)
    elif model == 'ResNet101':
        model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes)
    elif model == 'ResNet152':
        # Uses torchvision's ImageNet-style ResNet-152; the unused cifar10_models
        # import that preceded this call was removed.
        model = torchvision.models.resnet152(pretrained=False, num_classes=num_classes)
    elif model == 'DenseNet121':
        from .cifar10_models.densenet import densenet121
        model = densenet121(pretrained=False)
    elif model == 'GoogleNet':
        from .cifar10_models.googlenet import googlenet
        model = googlenet(pretrained=False)
    elif model == "Inception_v3":
        from .cifar10_models.inception import inception_v3
        model = inception_v3(pretrained=False)
    elif model == 'MobileNet':
        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 1],  # cifar adaptation, cf.https://github.com/kuangliu/pytorch-cifar/blob/master/models/mobilenetv2.py
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]
        model = torchvision.models.MobileNetV2(num_classes=num_classes,
                                               inverted_residual_setting=inverted_residual_setting,
                                               width_mult=1.0)
        # Replace the stem with a stride-1 conv for small inputs; fixed to width 32.
        model.features[0] = torchvision.models.mobilenet.ConvBNReLU(num_channels, 32, stride=1)
    elif model == 'MNASNet':
        model = torchvision.models.MNASNet(1.0, num_classes=num_classes, dropout=0.2)
    elif model == 'DenseNet40':
        model = _DenseNet(_Bottleneck, [6, 6, 6, 0], growth_rate=12, num_classes=num_classes)
    elif model == 'DenseNet40-4':
        model = _DenseNet(_Bottleneck, [6, 6, 6, 0], growth_rate=12 * 4, num_classes=num_classes)
    elif model == 'SRNet3':
        model = SRNet(upscale_factor=3, num_channels=num_channels)
    elif model == 'SRNet1':
        model = SRNet(upscale_factor=1, num_channels=num_channels)
    elif model == 'iRevNet':
        if num_classes <= 100:
            in_shape = [num_channels, 32, 32]  # only for cifar right now
            model = iRevNet(nBlocks=[18, 18, 18], nStrides=[1, 2, 2],
                            nChannels=[16, 64, 256], nClasses=num_classes,
                            init_ds=0, dropout_rate=0.1, affineBN=True,
                            in_shape=in_shape, mult=4)
        else:
            in_shape = [3, 224, 224]  # only for imagenet
            model = iRevNet(nBlocks=[6, 16, 72, 6], nStrides=[2, 2, 2, 2],
                            nChannels=[24, 96, 384, 1536], nClasses=num_classes,
                            init_ds=2, dropout_rate=0.1, affineBN=True,
                            in_shape=in_shape, mult=4)
    elif model == 'LeNetZhu':
        model = LeNetZhu(num_channels=num_channels, num_classes=num_classes)
    elif model == 'vit':
        from vit_pytorch import ViT
        model = ViT(
            image_size=32,
            patch_size=8,
            num_classes=num_classes,
            dim=1024,
            depth=4,
            heads=6,
            mlp_dim=1024,
        )
    elif model == 'swin':
        from swin_transformer import SwinTransformer
        model = SwinTransformer(img_size=32,
                                patch_size=4,
                                in_chans=3,
                                num_classes=num_classes,
                                embed_dim=96,
                                depths=[2, 2, 6, 2],
                                num_heads=[3, 6, 12, 24],
                                window_size=4,
                                mlp_ratio=4.,
                                qkv_bias=True,
                                qk_scale=None,
                                drop_rate=0,
                                drop_path_rate=0.1,
                                ape=False,
                                patch_norm=True)
        model.head = torch.nn.Linear(model.head.weight.shape[1], num_classes)
    else:
        raise NotImplementedError('Model not implemented.')

    print(f'Model initialized with random key {model_init_seed}.')
    return model, model_init_seed

class AlexNet(nn.Module):
    """AlexNet classifier in the torchvision layout with a configurable input depth.

    Args:
        num_classes: width of the final linear layer.
        num_channels: number of channels of the input images.
    """

    def __init__(self, num_classes=10, num_channels=3):
        super(AlexNet, self).__init__()
        feature_layers = [
            nn.Conv2d(num_channels, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*feature_layers)
        # Fixes the spatial size fed to the classifier at 6x6 regardless of input size.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        head_layers = [
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        pooled = self.avgpool(self.features(x))
        flat = pooled.view(pooled.size(0), 256 * 6 * 6)
        return self.classifier(flat)
class BasicBlock(nn.Module):
    """Two 3x3 conv + BN stages with an identity or 1x1-projection shortcut.

    The shortcut becomes a strided 1x1 conv + BN whenever the stride is not 1
    or the channel count changes; otherwise it is an empty ``nn.Sequential``.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        out_planes = self.expansion * planes
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        # The clone() on the shortcut branch is kept from the original implementation.
        summed = hidden + self.shortcut(x).clone()
        return F.relu(summed)


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck block whose output has ``4 * planes`` channels.

    A strided 1x1 conv + BN projection shortcut is used when the stride is not 1
    or ``in_planes`` differs from the output width; otherwise the shortcut is an
    empty ``nn.Sequential`` (identity).
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        # The clone() on the shortcut branch is kept from the original implementation.
        hidden = hidden + self.shortcut(x).clone()
        return F.relu(hidden)




class ResNet(nn.Module):
    """Four-stage CIFAR-style ResNet (64/128/256/512 planes) returning log-probabilities.

    Args:
        block: block class built as ``block(in_planes, planes, stride)`` and
            exposing an ``expansion`` class attribute.
        num_blocks: number of blocks in each of the four stages (indexed 0..3).
        num_classes: number of output classes.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (planes, stage_stride) in enumerate(stage_specs):
            # Registers self.layer1 .. self.layer4 in order.
            setattr(self, 'layer%d' % (idx + 1),
                    self._make_layer(block, planes, num_blocks[idx], stride=stage_stride))
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # The first block carries the stage stride; at least one block is always built.
        layers = []
        for i in range(max(num_blocks, 1)):
            layers.append(block(self.in_planes, planes, stride if i == 0 else 1))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        # Fixed 4x4 pooling: the linear layer expects a 1x1 map here, which
        # presumes 32x32 inputs (total stage stride is 8).
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return F.log_softmax(self.linear(out), dim=1)


import torch.nn.functional as F
class SimpleCNN(nn.Module):
    """Three conv/max-pool stages plus a 3-layer MLP head for 3x32x32 inputs.

    ``forward`` returns log-probabilities of shape ``(batch, num_classes)``.

    Args:
        num_classes: number of output classes (default 10, the previously
            hard-coded value).
    """

    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3)
        self.conv2 = nn.Conv2d(64, 128, 3)
        self.conv3 = nn.Conv2d(128, 256, 3)
        self.pool = nn.MaxPool2d(2, 2)
        # 32x32 input: 32 -conv-> 30 -pool-> 15 -conv-> 13 -pool-> 6 -conv-> 4 -pool-> 2,
        # leaving 256 * 2 * 2 = 1024 features per sample.
        self.fc1 = nn.Linear(256 * 2 * 2, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten per sample. The previous view(-1, 1024) silently regrouped
        # samples across the batch whenever the input was not 32x32.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

class SimpleCNNMnist(nn.Module):
    """Two conv/ReLU/max-pool stages and one linear layer for single-channel images.

    The linear layer expects a 64 x 6 x 6 feature map, which a 1x32x32 input
    produces (32 -conv-> 30 -pool-> 15 -conv-> 13 -pool-> 6).

    Args:
        num_classes: number of output logits (default 10, the previously
            hard-coded value).
    """

    def __init__(self, num_classes=10):
        super(SimpleCNNMnist, self).__init__()

        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.fc = nn.Linear(64 * 6 * 6, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)

        # Flatten per sample. The previous view(-1, 2304) silently regrouped
        # samples across the batch for inputs of other sizes (e.g. 28x28).
        x = torch.flatten(x, 1)
        return self.fc(x)

class ConvNet(torch.nn.Module):
    """ConvNetBN: stacked 3x3 conv + BatchNorm + ReLU stages, two 3x3 max-pools, linear head.

    Args:
        width: base channel count; stages use 1x, 2x and 4x this width.
        num_classes: number of output logits.
        num_channels: number of input channels.
    """

    def __init__(self, width=32, num_classes=10, num_channels=3):
        """Init with width and num classes."""
        super().__init__()
        # NOTE(review): 'conv6', 'bn6' and 'relu6' appear twice in the list below.
        # OrderedDict keeps the first position but the later value, so the network
        # contains only one conv6/bn6/relu6 stage. The discarded first Conv2d is
        # still constructed (and randomly initialised), so removing the duplicate
        # would change the seeded initialisation of later layers.
        self.model = torch.nn.Sequential(OrderedDict([
            ('conv0', torch.nn.Conv2d(num_channels, 1 * width, kernel_size=3, padding=1)),
            ('bn0', torch.nn.BatchNorm2d(1 * width)),
            ('relu0', torch.nn.ReLU()),

            ('conv1', torch.nn.Conv2d(1 * width, 2 * width, kernel_size=3, padding=1)),
            ('bn1', torch.nn.BatchNorm2d(2 * width)),
            ('relu1', torch.nn.ReLU()),

            ('conv2', torch.nn.Conv2d(2 * width, 2 * width, kernel_size=3, padding=1)),
            ('bn2', torch.nn.BatchNorm2d(2 * width)),
            ('relu2', torch.nn.ReLU()),

            ('conv3', torch.nn.Conv2d(2 * width, 4 * width, kernel_size=3, padding=1)),
            ('bn3', torch.nn.BatchNorm2d(4 * width)),
            ('relu3', torch.nn.ReLU()),

            ('conv4', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)),
            ('bn4', torch.nn.BatchNorm2d(4 * width)),
            ('relu4', torch.nn.ReLU()),

            ('conv5', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)),
            ('bn5', torch.nn.BatchNorm2d(4 * width)),
            ('relu5', torch.nn.ReLU()),

            ('pool0', torch.nn.MaxPool2d(3)),

            ('conv6', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)),
            ('bn6', torch.nn.BatchNorm2d(4 * width)),
            ('relu6', torch.nn.ReLU()),

            # Duplicate keys: these entries replace the conv6/bn6/relu6 values above.
            ('conv6', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)),
            ('bn6', torch.nn.BatchNorm2d(4 * width)),
            ('relu6', torch.nn.ReLU()),

            ('conv7', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)),
            ('bn7', torch.nn.BatchNorm2d(4 * width)),
            ('relu7', torch.nn.ReLU()),

            ('pool1', torch.nn.MaxPool2d(3)),
            ('flatten', torch.nn.Flatten()),
            # 36 * width = (4 * width channels) * 3 * 3, which holds for 32x32
            # inputs (32 -> 10 -> 3 after the two MaxPool2d(3) layers).
            ('linear', torch.nn.Linear(36 * width, num_classes))
        ]))

    def forward(self, input):
        return self.model(input)
class LeNet(nn.Module):
    """DLG-style LeNet: three sigmoid 5x5 convolutions followed by one linear layer.

    Args:
        channel: number of input channels.
        hidden: flattened feature size produced by ``body``; the default
            768 = 12 * 8 * 8 corresponds to 32x32 inputs.
        num_classes: number of output logits.
    """

    def __init__(self, channel=1, hidden=768, num_classes=10):
        super(LeNet, self).__init__()
        act = nn.Sigmoid

        self.body = nn.Sequential(
            nn.Conv2d(channel, 12, kernel_size=5, padding=5 // 2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=1),
            act()
        )

        self.fc = nn.Sequential(
            nn.Linear(hidden, num_classes)
        )

    def forward(self, x):
        out = self.body(x)
        # Flatten per sample so that ``hidden`` is honoured; the previous
        # view(-1, 768) ignored it and could regroup samples across the batch.
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

class LeNet_em(nn.Module):
    """Classic LeNet-5 variant for 1x28x28 inputs with a 27-way output layer.

    Two conv/ReLU/max-pool stages (28 -> 24 -> 12 -> 8 -> 4) give 16 * 4 * 4 = 256
    features, which feed a 256 -> 120 -> 84 -> 27 MLP.
    """

    def __init__(self):
        super(LeNet_em, self).__init__()
        # Feature extractor.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        # Classifier head.
        self.fc1 = nn.Linear(256, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 27)  # 27 classes for EMNIST letters

    def forward(self, x):
        pipeline = (
            self.conv1, self.relu1, self.maxpool1,
            self.conv2, self.relu2, self.maxpool2,
            self.flatten,
            self.fc1, self.relu3,
            self.fc2, self.relu4,
            self.fc3,
        )
        for stage in pipeline:
            x = stage(x)
        return x
class LeNetZhu(nn.Module):
    """LeNet variant from https://github.com/mit-han-lab/dlg/blob/master/models/vision.py."""

    def __init__(self, num_classes=10, num_channels=3):
        """Build three sigmoid 5x5 convolutions and a 768-input linear layer.

        Every parameter is then re-initialised uniformly in [-0.5, 0.5].
        """
        super().__init__()
        act = nn.Sigmoid
        self.body = nn.Sequential(
            nn.Conv2d(num_channels, 12, kernel_size=5, padding=5 // 2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=1),
            act(),
        )
        self.fc = nn.Sequential(
            nn.Linear(768, num_classes)
        )
        # Pre-order traversal (self first, then children) fixes the RNG draw order.
        for submodule in self.modules():
            self.weights_init(submodule)

    @staticmethod
    def weights_init(m):
        # Weight is drawn before bias, for any module that has the attribute.
        for attr_name in ("weight", "bias"):
            if hasattr(m, attr_name):
                getattr(m, attr_name).data.uniform_(-0.5, 0.5)

    def forward(self, x):
        features = self.body(x)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)

class ConfigurableCNN(nn.Module):
    """Four-conv CNN with switchable bias, ReLU, dropout and a final max-pool.

    The input is assumed to be 3x32x32 when sizing the first linear layer; the
    output has 10 logits.

    Args:
        use_bias: give every conv/linear layer a bias term.
        use_relu: insert ReLU after each layer (otherwise nn.Identity).
        use_dropout: insert Dropout(p=0.5) after the first two linear layers.
        use_maxpool: apply a 2x2 max-pool after the last conv.
        kernel_size, padding: passed to every nn.Conv2d.
    """

    def __init__(self, use_bias, use_relu, use_dropout, use_maxpool, kernel_size, padding):
        super(ConfigurableCNN, self).__init__()

        def activation():
            return nn.ReLU() if use_relu else nn.Identity()

        conv_modules = []
        for c_in, c_out in ((3, 16), (16, 32), (32, 64), (64, 128)):
            conv_modules.append(nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=kernel_size,
                                          padding=padding, bias=use_bias))
            conv_modules.append(activation())
        conv_modules.append(nn.MaxPool2d(kernel_size=2, stride=2) if use_maxpool else nn.Identity())
        self.conv_layers = nn.Sequential(*conv_modules)

        # Flattened feature size of conv_layers for a 3x32x32 input.
        self.num_conv_features = self._calculate_conv_output_size()

        fc_modules = []
        for f_in, f_out in ((self.num_conv_features, 256), (256, 128)):
            fc_modules.append(nn.Linear(f_in, f_out, bias=use_bias))
            fc_modules.append(activation())
            fc_modules.append(nn.Dropout(p=0.5) if use_dropout else nn.Identity())
        fc_modules.append(nn.Linear(128, 10, bias=use_bias))  # 10 classes for CIFAR-10
        self.fc_layers = nn.Sequential(*fc_modules)

    def _calculate_conv_output_size(self):
        # Push a dummy zero batch through the conv stack and read off the per-sample size.
        probe = self.conv_layers(torch.zeros(4, 3, 32, 32))
        return probe.view(probe.size(0), -1).size(1)

    def forward(self, x):
        features = self.conv_layers(x)
        return self.fc_layers(features.view(features.size(0), -1))


class ConvNet_Config(nn.Module):
    """ConvNet-style stack of eight convolutions with switchable bias/ReLU/dropout/pooling.

    Conv ``i`` has ``width * [1, 2, 2, 4, 4, 4, 4, 4][i]`` output channels. When
    ``use_maxpool`` is set, a MaxPool2d(3) follows conv5 and conv7. The linear
    layer is sized from a dummy ``num_channels x 32 x 32`` input.

    Args:
        use_bias: give every conv/linear layer a bias term.
        use_relu: insert ReLU after each conv (otherwise nn.Identity).
        use_dropout: apply Dropout(p=0.5) before the linear layer.
        use_maxpool: insert the two MaxPool2d(3) layers.
        kernel_size, padding: passed to every nn.Conv2d.
        width: base channel width.
        num_classes: number of output logits.
        num_channels: number of input channels.
    """

    def __init__(self, use_bias, use_relu, use_dropout, use_maxpool, kernel_size, padding, width=64, num_classes=10, num_channels=3):
        super(ConvNet_Config, self).__init__()

        width_multipliers = [1, 2, 2, 4, 4, 4, 4, 4]
        layers = []
        in_channels = num_channels
        for i, mult in enumerate(width_multipliers):
            out_channels = width * mult
            layers.append(('conv' + str(i), nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=use_bias)))
            layers.append(('relu' + str(i), nn.ReLU() if use_relu else nn.Identity()))
            in_channels = out_channels

            if use_maxpool and i in (5, 7):
                layers.append(('pool' + str(i), nn.MaxPool2d(3)))

        self.conv_layers = nn.Sequential(OrderedDict(layers))

        self.dropout1 = nn.Dropout(p=0.5) if use_dropout else nn.Identity()

        # Previously hard-coded to 10 outputs, silently ignoring ``num_classes``.
        self.linear1 = nn.Linear(self._calculate_conv_output_size(num_channels, width), num_classes, bias=use_bias)

    def _calculate_conv_output_size(self, num_channels, width):
        # ``width`` is unused; kept for interface compatibility.
        dummy_input = torch.zeros(1, num_channels, 32, 32)
        dummy_output = self.conv_layers(dummy_input)
        return dummy_output.view(dummy_output.size(0), -1).size(1)

    def forward(self, input):
        x = self.conv_layers(input)
        x = x.view(x.size(0), -1)
        x = self.dropout1(x)
        return self.linear1(x)
class CustomBlock(nn.Module):
    """Residual block with optional ReLU and dropout.

    When ``stride != 1`` the shortcut is ``bn1(conv1(x))`` — it reuses the
    main path's first conv and BN rather than a separate projection. When
    ``stride == 1`` the identity shortcut is used, so the residual add requires
    ``in_planes == planes``.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, bias=True, use_relu=True, use_dropout=True):
        super(CustomBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=bias)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU() if use_relu else nn.Identity()
        self.dropout = nn.Dropout(0.5) if use_dropout else nn.Identity()
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=bias)
        self.bn2 = nn.BatchNorm2d(planes)
        self.stride = stride

    def forward(self, x):
        out = self.dropout(self.relu(self.bn1(self.conv1(x))))
        out = self.bn2(self.conv2(out))

        # Computed after the main path so bn1 sees its batches in the same order.
        shortcut = self.bn1(self.conv1(x)) if self.stride != 1 else x

        out += shortcut
        return self.dropout(self.relu(out))

class CustomResNet(nn.Module):
    """ResNet-34-shaped network (3/4/6/3 ``CustomBlock``s) with a configurable stem.

    Args:
        num_classes: number of output logits.
        bias: give every convolution a bias term.
        use_relu, use_dropout: toggles applied to the stem and to every block.
        kernel_size: stem kernel — 'small' -> 1, 'large' -> 5, anything else -> 3.
        padding: stem padding — 'none' -> 0, 'more' -> 2, anything else -> 1.
    """

    def __init__(self, num_classes=10, bias=True, use_relu=True, use_dropout=True, kernel_size='medium', padding='same'):
        super(CustomResNet, self).__init__()
        self.in_planes = 64

        conv1_kernel_size = 1 if kernel_size == 'small' else (5 if kernel_size == 'large' else 3)
        conv1_padding = 0 if padding == 'none' else (2 if padding == 'more' else 1)

        self.conv1 = nn.Conv2d(3, 64, kernel_size=conv1_kernel_size, stride=1, padding=conv1_padding, bias=bias)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU() if use_relu else nn.Identity()
        self.dropout = nn.Dropout(0.5) if use_dropout else nn.Identity()

        block_opts = dict(bias=bias, use_relu=use_relu, use_dropout=use_dropout)
        self.layer1 = self._make_layer(64, 3, stride=1, **block_opts)
        self.layer2 = self._make_layer(128, 4, stride=2, **block_opts)
        self.layer3 = self._make_layer(256, 6, stride=2, **block_opts)
        self.layer4 = self._make_layer(512, 3, stride=2, **block_opts)

        # Adaptive pooling makes the head independent of the final spatial size.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * CustomBlock.expansion, num_classes)

    def _make_layer(self, planes, num_blocks, stride, bias, use_relu, use_dropout):
        # Only the first block is strided; at least one block is always built.
        blocks = [CustomBlock(self.in_planes, planes, stride, bias, use_relu, use_dropout)]
        self.in_planes = planes * CustomBlock.expansion
        blocks.extend(
            CustomBlock(self.in_planes, planes, bias=bias, use_relu=use_relu, use_dropout=use_dropout)
            for _ in range(1, num_blocks)
        )
        return nn.Sequential(*blocks)

    def forward(self, x):
        out = self.dropout(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        out = torch.flatten(self.avgpool(out), 1)
        return self.fc(out)


class ConfigurableVGG11(nn.Module):
    """VGG-11 whose bias, ReLU, dropout and max-pool layers can be switched off.

    Args:
        num_classes: number of output logits.
        use_bias: give every conv and linear layer a bias term.
        use_relu: insert ReLU after each conv (in-place) and after the hidden
            linear layers; otherwise the classifier uses ``nn.Identity``.
        use_dropout: insert ``nn.Dropout()`` after each conv and
            ``nn.Dropout(0.5)`` in the classifier; otherwise ``nn.Identity``.
        use_maxpool: keep the five 2x2/stride-2 max-pool layers.
        conv_kernel_size: kernel size of every conv layer.
        conv_padding: padding of every conv layer.

    The feature maps are adaptively average-pooled to 7x7 before the
    three-layer classifier, so the classifier width does not depend on the
    input resolution. Expects 3-channel input.
    """

    def __init__(self, num_classes=1000, use_bias=True, use_relu=True, use_dropout=True, use_maxpool=True,
                 conv_kernel_size=3, conv_padding=1):
        super().__init__()

        # VGG-11 plan: ints are conv output widths, 'M' marks a max-pool slot.
        vgg11_plan = (64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M')
        feature_layers = []
        channels = 3
        for entry in vgg11_plan:
            if entry == 'M':
                if use_maxpool:
                    feature_layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            feature_layers.append(
                nn.Conv2d(channels, entry, kernel_size=conv_kernel_size, padding=conv_padding, bias=use_bias)
            )
            if use_relu:
                feature_layers.append(nn.ReLU(inplace=True))
            if use_dropout:
                feature_layers.append(nn.Dropout())
            channels = entry
        self.features = nn.Sequential(*feature_layers)

        def activation():
            return nn.ReLU() if use_relu else nn.Identity()

        def regularizer():
            return nn.Dropout(0.5) if use_dropout else nn.Identity()

        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # Arguments are evaluated left to right, so the three Linear layers are
        # created (and randomly initialised) in the same order as before.
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096, bias=use_bias),
            activation(),
            regularizer(),
            nn.Linear(4096, 4096, bias=use_bias),
            activation(),
            regularizer(),
            nn.Linear(4096, num_classes, bias=use_bias),
        )

    def forward(self, x):
        """Map an image batch to class logits of shape (batch, num_classes)."""
        pooled = self.avgpool(self.features(x))
        return self.classifier(pooled.flatten(1))
# Layer plans for the VGG family consumed by make_layers: an int is the output
# width of a 3x3 conv, 'M' inserts a 2x2/stride-2 max-pool.
# 'A' -> VGG-11, 'B' -> VGG-13, 'D' -> VGG-16, 'E' -> VGG-19.
cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
          512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
          512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

class VGG(nn.Module):
    """VGG classifier head on top of a caller-supplied feature extractor.

    Args:
        features: module producing feature maps that flatten to exactly 512
            values per sample (e.g. 512x1x1, as make_layers yields for 32x32
            input with five max-pools -- assumption, confirm per use).
        num_class: number of output logits.
    """

    def __init__(self, features, num_class=100):
        super().__init__()
        self.features = features

        width = 4096
        # Non in-place ReLUs, matching make_layers.
        self.classifier = nn.Sequential(
            nn.Linear(512, width),
            nn.ReLU(inplace=False),
            nn.Dropout(),
            nn.Linear(width, width),
            nn.ReLU(inplace=False),
            nn.Dropout(),
            nn.Linear(width, num_class),
        )

    def forward(self, x):
        """Return logits of shape (batch, num_class)."""
        feature_maps = self.features(x)
        flat = feature_maps.view(feature_maps.size(0), -1)
        return self.classifier(flat)

def make_layers(cfg, batch_norm=False):
    """Build a VGG convolutional trunk from a layer plan.

    Args:
        cfg: sequence of entries; an int adds a 3x3/padding-1 conv with that
            many output channels (followed by optional BatchNorm2d and a
            non in-place ReLU), and 'M' adds a 2x2/stride-2 max-pool.
        batch_norm: insert BatchNorm2d after every conv.

    Returns:
        nn.Sequential expecting 3-channel input.
    """
    modules = []
    channels = 3
    for spec in cfg:
        if spec != 'M':
            modules.append(nn.Conv2d(channels, spec, kernel_size=3, padding=1))
            if batch_norm:
                modules.append(nn.BatchNorm2d(spec))
            modules.append(nn.ReLU(inplace=False))
            channels = spec
        else:
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*modules)

def vgg11_bn(num_class=100):
    """Build a VGG-11 with batch norm from layer plan ``cfg['A']``.

    Args:
        num_class: number of output logits; defaults to 100 (VGG's default).
    """
    return VGG(make_layers(cfg['A'], batch_norm=True), num_class=num_class)

def vgg13_bn(num_class=100):
    """Build a VGG-13 with batch norm from layer plan ``cfg['B']``.

    Args:
        num_class: number of output logits; defaults to 100 (VGG's default).
    """
    return VGG(make_layers(cfg['B'], batch_norm=True), num_class=num_class)

def vgg16_bn(num_class=100):
    """Build a VGG-16 with batch norm from layer plan ``cfg['D']``.

    Args:
        num_class: number of output logits; defaults to 100 (VGG's default).
    """
    return VGG(make_layers(cfg['D'], batch_norm=True), num_class=num_class)

def vgg19_bn(num_class=100):
    """Build a VGG-19 with batch norm from layer plan ``cfg['E']``.

    Args:
        num_class: number of output logits; defaults to 100 (VGG's default).
    """
    return VGG(make_layers(cfg['E'], batch_norm=True), num_class=num_class)

class ViT_cifar(nn.Module):
    """Vision Transformer classifier for small square images (CIFAR-sized by default).

    The input is cut into a ``patch`` x ``patch`` grid of non-overlapping square
    patches. Each patch is flattened and linearly embedded; an optional learnable
    class token is prepended, a learnable positional embedding is added, and the
    token sequence goes through ``num_layers`` TransformerEncoder blocks. Logits
    come from the class token (``is_cls_token=True``) or from the mean over all
    tokens (``is_cls_token=False``).

    Args:
        in_c: number of input channels.
        num_classes: number of output logits.
        img_size: input height and width in pixels. Assumed divisible by
            ``patch`` (TODO confirm callers guarantee this).
        patch: number of patches along one side of the image -- NOT the patch
            side length, which is ``img_size // patch``.
        dropout: dropout probability passed to every encoder block.
        num_layers: number of encoder blocks.
        hidden: token embedding width.
        mlp_hidden: hidden width of each encoder block's MLP.
        head: number of attention heads per encoder block.
        is_cls_token: prepend a class token and classify from it.
    """
    def __init__(self, in_c:int=3, num_classes:int=10, img_size:int=32, patch:int=8, dropout:float=0., num_layers:int=7, hidden:int=384, mlp_hidden:int=384*4, head:int=8, is_cls_token:bool=True):
        super(ViT_cifar, self).__init__()

        self.patch = patch  # number of patches in one row (or column)
        self.is_cls_token = is_cls_token
        self.patch_size = img_size // self.patch  # side length of one patch, in pixels
        # Length of one flattened patch vector, matching what _to_words emits:
        # patch_size * patch_size pixels times in_c channels. Using in_c (rather
        # than a hard-coded 3) lets single-channel or other non-RGB inputs work.
        # With the defaults this is 4 * 4 * 3 = 48.
        f = self.patch_size**2 * in_c
        num_tokens = (self.patch**2) + 1 if self.is_cls_token else (self.patch**2)

        self.emb = nn.Linear(f, hidden)  # (b, n, f) -> (b, n, hidden)
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden)) if is_cls_token else None
        self.pos_emb = nn.Parameter(torch.randn(1, num_tokens, hidden))
        enc_list = [TransformerEncoder(hidden, mlp_hidden=mlp_hidden, dropout=dropout, head=head) for _ in range(num_layers)]
        self.enc = nn.Sequential(*enc_list)
        self.fc = nn.Sequential(
            nn.LayerNorm(hidden),
            nn.Linear(hidden, num_classes),  # applied to the class token (or the token mean)
        )

    def forward(self, x):
        """Map images of shape (b, in_c, img_size, img_size) to logits (b, num_classes)."""
        out = self._to_words(x)
        out = self.emb(out)
        if self.is_cls_token:
            # Prepend one copy of the learnable class token per sample.
            out = torch.cat([self.cls_token.repeat(out.size(0), 1, 1), out], dim=1)
        out = out + self.pos_emb
        out = self.enc(out)
        if self.is_cls_token:
            out = out[:, 0]
        else:
            out = out.mean(1)
        out = self.fc(out)
        return out

    def _to_words(self, x):
        """Split images into flattened patches.

        (b, c, h, w) -> (b, patch**2, patch_size * patch_size * c); patches are
        ordered row-major and each vector is laid out as (row, col, channel).
        """
        out = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size).permute(0, 2, 3, 4, 5, 1)
        out = out.reshape(x.size(0), self.patch**2, -1)
        return out

class TransformerEncoder(nn.Module):
    """Pre-norm Transformer encoder block.

    Computes ``y = x + msa(norm1(x))`` followed by ``y + mlp(norm2(y))``.

    Args:
        feats: token width (last dimension of the input).
        mlp_hidden: hidden width of the MLP.
        head: number of attention heads.
        dropout: dropout probability passed to the attention module and used
            by the MLP's dropout layers.
    """

    def __init__(self, feats:int, mlp_hidden:int, head:int=8, dropout:float=0.):
        super().__init__()
        self.la1 = nn.LayerNorm(feats)
        self.msa = MultiHeadSelfAttention(feats, head=head, dropout=dropout)
        self.la2 = nn.LayerNorm(feats)
        # NOTE(review): the GELU after the second Linear is unusual for a ViT MLP
        # (commonly Linear-GELU-Dropout-Linear-Dropout); kept to preserve behavior.
        mlp_layers = [
            nn.Linear(feats, mlp_hidden), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(mlp_hidden, feats), nn.GELU(), nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Apply the attention and MLP sub-layers, each with a residual connection."""
        attended = x + self.msa(self.la1(x))
        return attended + self.mlp(self.la2(attended))


class MultiHeadSelfAttention(nn.Module):
    """Multi-head dot-product self-attention over a token sequence.

    Args:
        feats: token width; must be divisible by ``head``.
        head: number of attention heads.
        dropout: dropout probability applied to the output projection.

    Raises:
        ValueError: if ``feats`` is not divisible by ``head``.
    """
    def __init__(self, feats:int, head:int=8, dropout:float=0.):
        super(MultiHeadSelfAttention, self).__init__()
        if feats % head != 0:
            # Fail at construction rather than with an opaque view() error in forward.
            raise ValueError(f"feats ({feats}) must be divisible by head ({head})")
        self.head = head
        self.feats = feats
        # NOTE(review): scores are scaled by sqrt(feats), not by the per-head width
        # sqrt(feats // head) of standard scaled dot-product attention. Kept as-is so
        # existing results and weights remain comparable.
        self.sqrt_d = self.feats**0.5

        self.q = nn.Linear(feats, feats)
        self.k = nn.Linear(feats, feats)
        self.v = nn.Linear(feats, feats)

        self.o = nn.Linear(feats, feats)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape (b, n, feats) and return the same shape."""
        b, n, f = x.size()
        head_dim = self.feats // self.head
        # (b, n, feats) -> (b, head, n, head_dim)
        q = self.q(x).view(b, n, self.head, head_dim).transpose(1, 2)
        k = self.k(x).view(b, n, self.head, head_dim).transpose(1, 2)
        v = self.v(x).view(b, n, self.head, head_dim).transpose(1, 2)

        # torch.softmax is used because ``F`` (torch.nn.functional) is not among
        # this module's top-level imports; the result is identical.
        score = torch.softmax(torch.einsum("bhif, bhjf->bhij", q, k) / self.sqrt_d, dim=-1)  # (b, h, n, n)
        attn = torch.einsum("bhij, bhjf->bihf", score, v)  # (b, n, h, head_dim)
        o = self.dropout(self.o(attn.flatten(2)))
        return o

class MultiHeadDepthwiseSelfAttention(nn.Module):
    """Placeholder for a depthwise variant of multi-head self-attention.

    NOTE(review): as visible here this appears to be an unimplemented stub --
    ``__init__`` only initialises the ``nn.Module`` base and ``forward``'s body
    is ``...``, so calling the module would return ``None``. Confirm nothing
    constructs it before relying on it.
    """
    def __init__(self, feats:int, head:int=8, dropout:float=0):
        super(MultiHeadDepthwiseSelfAttention, self).__init__()
        # TODO(review): feats, head and dropout are accepted but not used yet.
        ...

    def forward(self, x):
        # Stub body: no computation is performed here.
        ...